/*
 * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

namespace ark::libllvmbackend {

/**
 * Tells the pass instrumentation which pipeline name belongs to each generated pass class.
 *
 * One addClassToPassName() call is emitted per entry of PassRegistry::llvm_passes, pairing the
 * fully qualified class name (as a string) with that class's ARG_NAME.
 */
void RegisterPasses(llvm::PassInstrumentationCallbacks &passInstrumentation)
{
    const auto bindClassToArg = [&passInstrumentation](llvm::StringRef qualifiedClass, llvm::StringRef argName) {
        passInstrumentation.addClassToPassName(qualifiedClass, argName);
    };

% PassRegistry::llvm_passes.each do |pass|
%   qualified_pass = "#{PassRegistry::passes_namespace}::#{pass.name}"
    bindClassToArg("<%= qualified_pass %>", <%= qualified_pass %>::ARG_NAME);
% end
}

// Shorthands for the LLVM and Ark types that appear in the pass-parsing signatures of this file.
using PipelineElements = llvm::ArrayRef<llvm::PassBuilder::PipelineElement>;
using LLVMCompilerOptions = ark::llvmbackend::LLVMCompilerOptions;
using StringRef = llvm::StringRef;

struct PassParser
{
public:
    explicit PassParser(ark::llvmbackend::LLVMArkInterface *arkInterface) : arkInterface_{arkInterface}
    {
    }

    bool ParseModulePasses(StringRef name, llvm::ModulePassManager &modulePm, const LLVMCompilerOptions &options);
    bool ParseFunctionPasses(StringRef name, llvm::FunctionPassManager &functionPm, const LLVMCompilerOptions &options);
    bool ParseSCCPasses(StringRef name, llvm::CGSCCPassManager &sccPm, const LLVMCompilerOptions &options);
    bool ParseLoopPasses(StringRef name, llvm::LoopPassManager &sccPm, const LLVMCompilerOptions &options);

    void RegisterParserCallbacks(llvm::PassBuilder &builder, const ark::llvmbackend::LLVMCompilerOptions &options)
    {
        builder.registerPipelineParsingCallback(
        [&](StringRef name, llvm::ModulePassManager &modulePm, PipelineElements /*unused*/) -> bool {
            return ParseModulePasses(name, modulePm, options);
        });
        builder.registerPipelineParsingCallback(
        [&](StringRef name, llvm::FunctionPassManager &functionPm, PipelineElements /*unused*/) -> bool {
            return ParseFunctionPasses(name, functionPm, options);
        });
        builder.registerPipelineParsingCallback(
        [&](StringRef name, llvm::CGSCCPassManager &sccPm, PipelineElements /*unused*/) -> bool {
            return ParseSCCPasses(name, sccPm, options);
        });
        builder.registerPipelineParsingCallback(
        [&](StringRef name, llvm::LoopPassManager &sccPm, PipelineElements /*unused*/) -> bool {
            return ParseLoopPasses(name, sccPm, options);
        });
    }
private:
    ark::llvmbackend::LLVMArkInterface *arkInterface_;
};

/**
 * Tries to resolve `name` to one of the generated module passes.
 *
 * @param name     pipeline element name to look up
 * @param modulePm pass manager that receives the pass when it should be inserted
 * @param options  compiler options consulted by ShouldInsert() and handed to Create()
 * @return true when `name` equals the ARG_NAME of a module pass (whether or not the pass
 *         was actually inserted), false when no module pass carries that name
 */
bool PassParser::ParseModulePasses(StringRef name, llvm::ModulePassManager &modulePm, const LLVMCompilerOptions &options)
{
    namespace pass = <%= PassRegistry::passes_namespace %>;

% PassRegistry::llvm_passes.select { |p| p.type.include?('module') }.each do |pass|
    if (name.equals(pass::<%= pass.name %>::ARG_NAME)) {
        // A recognized name is reported as handled even if the pass declines insertion.
        if (!pass::<%= pass.name %>::ShouldInsert(&options)) {
            return true;
        }
% if pass.setup == 'default'
        modulePm.addPass(pass::<%= pass.name %>());
% else
        modulePm.addPass(pass::<%= pass.name %>::Create(arkInterface_, &options));
% end
#ifndef NDEBUG
        // Builds without NDEBUG run the IR verifier right after every inserted pass.
        modulePm.addPass(llvm::VerifierPass());
#endif
        return true;
    }
% end

    return false;
}

/**
 * Tries to resolve `name` to one of the generated function passes.
 *
 * @param name       pipeline element name to look up
 * @param functionPm pass manager that receives the pass when it should be inserted
 * @param options    compiler options consulted by ShouldInsert() and handed to Create()
 * @return true when `name` equals the ARG_NAME of a function pass (whether or not the pass
 *         was actually inserted), false when no function pass carries that name
 */
bool PassParser::ParseFunctionPasses(StringRef name, llvm::FunctionPassManager &functionPm, const LLVMCompilerOptions &options)
{
    namespace pass = <%= PassRegistry::passes_namespace %>;

% PassRegistry::llvm_passes.select { |p| p.type.include?('function') }.each do |pass|
    if (name.equals(pass::<%= pass.name %>::ARG_NAME)) {
        // A recognized name is reported as handled even if the pass declines insertion.
        if (!pass::<%= pass.name %>::ShouldInsert(&options)) {
            return true;
        }
% if pass.setup == 'default'
        functionPm.addPass(pass::<%= pass.name %>());
% else
        functionPm.addPass(pass::<%= pass.name %>::Create(arkInterface_, &options));
% end
#ifndef NDEBUG
        // Builds without NDEBUG run the IR verifier right after every inserted pass.
        functionPm.addPass(llvm::VerifierPass());
#endif
        return true;
    }
% end

    return false;
}

/**
 * Tries to resolve `name` to one of the generated CGSCC passes.
 *
 * @param name    pipeline element name to look up
 * @param sccPm   pass manager that receives the pass when it should be inserted
 * @param options compiler options consulted by ShouldInsert() and handed to Create()
 * @return true when `name` equals the ARG_NAME of an SCC pass (whether or not the pass
 *         was actually inserted), false when no SCC pass carries that name
 */
bool PassParser::ParseSCCPasses(StringRef name, llvm::CGSCCPassManager &sccPm, const LLVMCompilerOptions &options)
{
    namespace pass = <%= PassRegistry::passes_namespace %>;

% PassRegistry::llvm_passes.select { |p| p.type.include?('scc') }.each do |pass|
    if (name.equals(pass::<%= pass.name %>::ARG_NAME)) {
        // A recognized name is reported as handled even if the pass declines insertion.
        if (!pass::<%= pass.name %>::ShouldInsert(&options)) {
            return true;
        }
% if pass.setup == 'default'
        sccPm.addPass(pass::<%= pass.name %>());
% else
        sccPm.addPass(pass::<%= pass.name %>::Create(arkInterface_, &options));
% end
        return true;
    }
% end

    return false;
}

/**
 * Tries to resolve `name` to one of the generated loop passes.
 *
 * @param name    pipeline element name to look up
 * @param loopPm  pass manager that receives the pass when it should be inserted
 * @param options compiler options consulted by ShouldInsert() and handed to Create()
 * @return true when `name` equals the ARG_NAME of a loop pass (whether or not the pass
 *         was actually inserted), false when no loop pass carries that name
 */
bool PassParser::ParseLoopPasses(StringRef name, llvm::LoopPassManager &loopPm, const LLVMCompilerOptions &options)
{
    namespace pass = <%= PassRegistry::passes_namespace %>;

% PassRegistry::llvm_passes.select { |p| p.type.include?('loop') }.each do |pass|
    if (name.equals(pass::<%= pass.name %>::ARG_NAME)) {
        // A recognized name is reported as handled even if the pass declines insertion.
        if (!pass::<%= pass.name %>::ShouldInsert(&options)) {
            return true;
        }
% if pass.setup == 'default'
        loopPm.addPass(pass::<%= pass.name %>());
% else
        loopPm.addPass(pass::<%= pass.name %>::Create(arkInterface_, &options));
% end
        return true;
    }
% end

    return false;
}

}  // namespace ark::libllvmbackend
